{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scaled Dot Product Attention (SDPA) Backward in cuDNN Frontend\n",
    "\n",
    "This operation computes the gradient tensors for scaled dot product attention using the FlashAttention-2 algorithm, as described in the paper *FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning*. The stats tensor produced by the forward operation must be passed to the backward operation as an input.\n",
    "\n",
    "The full documentation can be found in: [docs/operations/Attention.md#scaled-dot-product-attention-backward](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-backward)\n",
    "\n",
    "The Python test code for the full set of features can be found in: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/51_scaled_dot_product_attention_backward.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prerequisites and Setup\n",
    "This notebook requires an NVIDIA A100 GPU or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('export CUDA_VERSION=\"12.3\"')\n",
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12  | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "import math\n",
    "\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Check the environment before creating the cuDNN handle so failures are reported clearly\n",
    "assert torch.cuda.is_available(), \"A CUDA-capable GPU is required\"\n",
    "assert torch.cuda.get_device_capability()[0] >= 8, \"SDPA operation is only supported on SM80 architecture (Ampere) or above\"\n",
    "assert cudnn.backend_version() >= 8903, \"SDPA operation is only supported on cuDNN version 8.9.3 or above\"\n",
    "\n",
    "handle = cudnn.create_handle()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Problem sizes\n",
    "\n",
    "For this example, we will use the problem size from the original GPT-2 paper where:\n",
    " - maximum sequence length = 1024\n",
    " - hidden dim = number of heads $\\times$ embedding dimension per head = 12 $\\times$ 64 = 768"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = 4    # batch size\n",
    "h = 12   # number of query heads\n",
    "s = 1024 # maximum sequence length\n",
    "d = 64   # embedding dimension per head\n",
    "\n",
    "attn_scale = 1.0 / math.sqrt(d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the query, key, value, and output GPU tensors using PyTorch.\n",
    "\n",
    "**For the backward computation, the stats tensor must also be passed from the forward graph to the backward graph.**\n",
    "\n",
    "The stats tensor must have dims $(B, H, S, 1)$ and the float32 data type."
   ]
  },
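  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference for what the stats tensor holds, the following is a sketch that assumes cuDNN follows the FlashAttention-2 convention of saving the row-wise log-sum-exp of the scaled scores. The symbols $A$, $P$, $L$, and $\\alpha$ (standing for `attn_scale`) are notation introduced here, not names from the cuDNN API. For one batch and head, with the causal mask limiting the sums to $j \\le i$:\n",
    "\n",
    "$$A = \\alpha\\, Q K^\\top, \\qquad L_i = \\log \\sum_{j \\le i} \\exp(A_{ij}), \\qquad P_{ij} = \\exp(A_{ij} - L_i), \\qquad O = P V$$\n",
    "\n",
    "Under that assumption, $L$ has one value per query row, which matches the $(B, H, S, 1)$ shape above, and it lets the backward pass recompute $P$ without storing the full $S \\times S$ attention matrix."
   ]
  },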
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The tensors will have non-interleaved\n",
    "# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout\n",
    "# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout\n",
    "dims = (b, h, s, d)\n",
    "strides = (s * h * d, d, h * d, 1)\n",
    "\n",
    "q_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "k_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "v_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "o_gpu = torch.empty(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "stats_gpu = torch.empty(b, h, s, 1).float().cuda()"
   ]
  },
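  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the layout comment in the cell above concrete, here is a small worked example of the stride arithmetic using the problem sizes defined earlier. The index names $(i_b, i_h, i_s, i_d)$ are introduced here for illustration only. With `strides = (s * h * d, d, h * d, 1)`, the element at logical index $(i_b, i_h, i_s, i_d)$ sits at linear offset:\n",
    "\n",
    "$$\\text{offset} = i_b \\cdot (s\\,h\\,d) + i_h \\cdot d + i_s \\cdot (h\\,d) + i_d = 786432\\, i_b + 64\\, i_h + 768\\, i_s + i_d$$\n",
    "\n",
    "Because the sequence stride ($768$) is larger than the head stride ($64$), all heads of one token are contiguous in memory, which is the BSHD physical order, while the tensor is still indexed as BHSD."
   ]
  },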
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Also create the query, key, value, and output gradient tensors to be used in the backward computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# note: torch 'like' preserves the strided layout\n",
    "dQ_gpu = torch.empty_like(q_gpu)\n",
    "dK_gpu = torch.empty_like(k_gpu)\n",
    "dV_gpu = torch.empty_like(v_gpu)\n",
    "dO_gpu = torch.randn_like(o_gpu)"
   ]
  },
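  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation, the gradients that these tensors will hold can be written with the standard softmax-attention backward equations used by FlashAttention-2. This is a sketch of the math, not a description of cuDNN's kernels. $A$, $P$, $D$, and $\\alpha$ (for `attn_scale`) are notation introduced here, and $P$ is the causally masked row-wise softmax of $A = \\alpha\\, Q K^\\top$, so that $O = P V$:\n",
    "\n",
    "$$dV = P^\\top dO, \\qquad dP = dO\\, V^\\top, \\qquad D_i = \\sum_j P_{ij}\\, dP_{ij} = \\sum_k dO_{ik}\\, O_{ik}$$\n",
    "\n",
    "$$dA_{ij} = P_{ij}\\,(dP_{ij} - D_i), \\qquad dQ = \\alpha\\, dA\\, K, \\qquad dK = \\alpha\\, dA^\\top Q$$\n",
    "\n",
    "The $D$ term depends on $O$, which is consistent with the backward operation taking `o` and `dO` as inputs alongside `q`, `k`, `v`, and the stats tensor."
   ]
  },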
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create and build the forward graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_forward = cudnn.pygraph(\n",
    "    io_data_type=cudnn.data_type.HALF,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "q_forward = graph_forward.tensor_like(q_gpu)\n",
    "k_forward = graph_forward.tensor_like(k_gpu)\n",
    "v_forward = graph_forward.tensor_like(v_gpu)\n",
    "\n",
    "# training mode is enabled with is_inference=False\n",
    "# causal mask is enabled\n",
    "o_forward, stats_forward = graph_forward.sdpa(\n",
    "    name=\"sdpa\",\n",
    "    q=q_forward, k=k_forward, v=v_forward,\n",
    "    is_inference=False,\n",
    "    attn_scale=attn_scale,\n",
    "    use_causal_mask=True,\n",
    ")\n",
    "\n",
    "o_forward.set_output(True).set_dim(o_gpu.size()).set_stride(o_gpu.stride())\n",
    "stats_forward.set_output(True).set_dim(stats_gpu.size()).set_stride(stats_gpu.stride())\n",
    "stats_forward.set_data_type(cudnn.data_type.FLOAT)\n",
    "\n",
    "graph_forward.validate()\n",
    "graph_forward.build_operation_graph()\n",
    "graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "graph_forward.check_support()\n",
    "graph_forward.build_plans()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create and build the backward graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_backward = cudnn.pygraph(\n",
    "    io_data_type=cudnn.data_type.HALF,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "q_backward = graph_backward.tensor_like(q_gpu)\n",
    "k_backward = graph_backward.tensor_like(k_gpu)\n",
    "v_backward = graph_backward.tensor_like(v_gpu)\n",
    "o_backward = graph_backward.tensor_like(o_gpu)\n",
    "dO_backward = graph_backward.tensor_like(dO_gpu)\n",
    "stats_backward = graph_backward.tensor_like(stats_gpu)\n",
    "\n",
    "dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(\n",
    "    name=\"sdpa_backward\",\n",
    "    q=q_backward, k=k_backward, v=v_backward,\n",
    "    o=o_backward, dO=dO_backward, stats=stats_backward,\n",
    "    attn_scale=attn_scale,\n",
    "    use_causal_mask=True,\n",
    ")\n",
    "\n",
    "dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())\n",
    "dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())\n",
    "dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())\n",
    "\n",
    "graph_backward.validate()\n",
    "graph_backward.build_operation_graph()\n",
    "graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "graph_backward.check_support()\n",
    "graph_backward.build_plans()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Allocate the workspace required for execution. Because the forward and backward graphs are executed sequentially, they can share a single buffer sized to the larger of the two requirements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "workspace_size = max(\n",
    "    graph_forward.get_workspace_size(),\n",
    "    graph_backward.get_workspace_size(),\n",
    ")\n",
    "workspace = torch.empty(workspace_size, device=\"cuda\", dtype=torch.uint8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the forward graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_pack_forward = {\n",
    "    q_forward: q_gpu,\n",
    "    k_forward: k_gpu,\n",
    "    v_forward: v_gpu,\n",
    "    o_forward: o_gpu,\n",
    "    stats_forward: stats_gpu,\n",
    "}\n",
    "\n",
    "graph_forward.execute(variant_pack_forward, workspace)\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the backward graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_pack_backward = {\n",
    "    q_backward: q_gpu,\n",
    "    k_backward: k_gpu,\n",
    "    v_backward: v_gpu,\n",
    "    o_backward: o_gpu,\n",
    "    dO_backward: dO_gpu,\n",
    "    stats_backward: stats_gpu,\n",
    "    dQ_backward: dQ_gpu,\n",
    "    dK_backward: dK_gpu,\n",
    "    dV_backward: dV_gpu,\n",
    "}\n",
    "\n",
    "graph_backward.execute(variant_pack_backward, workspace)\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare cuDNN's output and gradients against a float32 PyTorch reference to check correctness."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_ref = q_gpu.detach().float().requires_grad_()\n",
    "k_ref = k_gpu.detach().float().requires_grad_()\n",
    "v_ref = v_gpu.detach().float().requires_grad_()\n",
    "dO_ref = dO_gpu.detach().float()\n",
    "\n",
    "o_ref = torch.nn.functional.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale)\n",
    "torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)\n",
    "\n",
    "dQ_ref, dK_ref, dV_ref = torch.autograd.grad(outputs=[o_ref], inputs=[q_ref, k_ref, v_ref], grad_outputs=[dO_ref])\n",
    "torch.testing.assert_close(dQ_ref, dQ_gpu.float(), atol=5e-3, rtol=3e-3)\n",
    "torch.testing.assert_close(dK_ref, dK_gpu.float(), atol=5e-3, rtol=3e-3)\n",
    "torch.testing.assert_close(dV_ref, dV_gpu.float(), atol=5e-3, rtol=3e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
